
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import math
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from torch import autograd
import torch.nn.functional as F
import numpy as np
from torch.autograd.function import InplaceFunction



class Conv2d_SparseNeuralGrad(nn.Conv2d):
    """``nn.Conv2d`` variant that can route its output through gradient sparsification.

    The forward value is always a plain convolution. Once ``iterCnt`` has
    reached 100 and ``fullName`` contains the substring ``'layer'``, the
    output is additionally passed through
    ``Conv2dFunctionSparseNeuralGradMult`` together with the ``N`` and ``M``
    buffers (fixed to 1 and 2 here).

    Nothing in this class updates ``iterCnt`` or ``fullName``; they keep
    their initial values (``0`` and ``''``) unless code outside this class
    assigns them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        keep_per_group, group_size = 1, 2
        self.fullName = ''
        self.iterCnt = 0
        # Registered as buffers so they follow the module across devices and
        # are included in its state dict.
        self.register_buffer('N', torch.tensor([keep_per_group]))
        self.register_buffer('M', torch.tensor([group_size]))

    def forward(self, input):
        """Convolve ``input`` and, when enabled, attach the sparsifying function."""
        conv_out = F.conv2d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

        # Original author's note: "all layers except first" -- presumably the
        # first layer's fullName lacks 'layer'; verify against the caller.
        warmed_up = self.iterCnt >= 100
        if not (warmed_up and 'layer' in self.fullName):
            return conv_out

        return Conv2dFunctionSparseNeuralGradMult.apply(conv_out, self.N, self.M)


class Conv2dFunctionSparseNeuralGradMult(autograd.Function):
    """Identity in forward; stochastic N:M sparsification of the gradient in backward.

    In backward the incoming 4-D gradient is permuted so that dimension 0
    becomes the last one, then split into consecutive groups of ``M``
    elements. Within each group ``N`` elements are sampled without
    replacement with probability proportional to their magnitude (plus a
    tiny epsilon); all other elements are set to zero.

    * ``N == 1``: the surviving element is replaced by
      ``sign(element) * sum(|group|)``.
    * otherwise: each surviving element is multiplied by a per-element
      scale computed from the group's magnitudes (the original comment
      calls this "Approx 2:4").
    """

    @staticmethod
    def forward(ctx, input, N=2, M=4):
        """Return ``input`` unchanged and remember ``N`` and ``M`` for backward.

        ``N`` and ``M`` may be Python ints or single-element tensors. They are
        stored on ``ctx`` as plain ints because ``save_for_backward`` only
        accepts tensors, which made the integer defaults unusable.
        """
        ctx.n = int(N)
        ctx.m = int(M)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Sparsify ``grad_output`` N:M and return it with the original layout.

        The number of elements in ``grad_output`` must be divisible by ``M``;
        otherwise the reshape below raises.
        """
        keep, group = ctx.n, ctx.m

        # Move dimension 0 last so that groups are formed along it.
        permuted = grad_output.permute(1, 2, 3, 0)
        groups = permuted.reshape(-1, group)
        magnitude = groups.abs()
        # Keeps the sampling weights strictly positive for all-zero groups.
        eps = torch.tensor([1e-30], dtype=groups.dtype, device=groups.device)
        totals = magnitude.sum(dim=1, keepdim=True)

        if keep == 1:
            # 1:M -- the sampled element carries the whole group's magnitude.
            picked = torch.multinomial(magnitude + eps, 1, replacement=False)
            pruned = torch.sign(groups) * torch.zeros_like(groups).scatter(1, picked, totals)
        else:
            # With p = |x| / sum|x|, `others` is 1 / (1 - p) per element.
            others = totals / (totals - magnitude + eps)
            others[others == float("inf")] = 1 / eps
            # NOTE(review): the subtraction of `others` from its row sum may
            # lose precision when one element dominates its group -- confirm.
            weight = (magnitude / (totals + eps)) * (
                1 - (group - 1) + others.sum(dim=1, keepdim=True) - others
            )
            scale = 1 / weight
            # Zero every non-finite scale. The previous check only caught
            # +inf, so an all-zero group (scale == -inf) produced NaN grads.
            scale = torch.where(torch.isfinite(scale), scale, torch.zeros_like(scale))
            picked = torch.multinomial(magnitude + eps, keep, replacement=False)
            # zeros_like keeps the mask in the gradient's dtype and device.
            selected = torch.zeros_like(groups).scatter(1, picked, 1)
            pruned = groups * (selected * scale)

        grad_input = pruned.reshape(permuted.shape).permute(3, 0, 1, 2)
        return grad_input, None, None


class SparseAct(nn.Module):
    """ReLU followed by N:M magnitude pruning along dimension 1.

    After the ReLU, the 4-D input is permuted so that dimension 1 comes
    last and split into consecutive groups of ``M`` values. In every group
    the ``M - N`` smallest values are zeroed and the remaining ``N`` are
    kept.

    Args:
        inplace: forwarded to ``F.relu``; when true the caller's tensor is
            modified by the ReLU.
        N: number of values kept per group.
        M: group size; the input's element count must be divisible by it.
        relu, trainable, **kwargs: accepted but currently ignored (the ReLU
            is always applied).
    """

    def __init__(self, inplace=True, N=2, M=4, relu=True, trainable=False, **kwargs):
        super().__init__()
        self.register_buffer('N', torch.tensor([N]))
        self.register_buffer('M', torch.tensor([M]))
        self.inplace = inplace
        self.biasFixEnable = True  # not read anywhere in this class

    def forward(self, input):
        """Return the pruned activation with the same shape and dtype as ``input``."""
        input = F.relu(input, inplace=self.inplace)

        group = int(self.M)
        n_drop = int(self.M - self.N)

        # Put dimension 1 last so each row of `groups` spans `group`
        # consecutive entries along it.
        moved = input.permute(0, 2, 3, 1)
        groups = moved.reshape(-1, group)

        # Indices of the n_drop smallest entries in every group.
        smallest = torch.argsort(groups, dim=1)[:, :n_drop]
        # ones_like keeps the mask in the activation's dtype; a default-dtype
        # mask silently promoted half/bfloat16 activations to float32.
        mask = torch.ones_like(groups).scatter_(dim=1, index=smallest, value=0)
        mask = mask.reshape(moved.shape).detach()

        out = moved * mask
        return out.permute(0, 3, 1, 2)




class UniformQuantizeSawb(InplaceFunction):
    """Uniform quantizer whose clip value is derived from input statistics.

    Forward computes ``clip = c1 * sqrt(mean(x**2)) - c2 * mean(|x|)`` and a
    step ``2 * clip / (Qp - Qn)``, then returns
    ``round(clamp(x / step, Qn, Qp)) * step`` on a copy of the input.
    Backward passes the incoming gradient through unchanged.

    There is no guard for a zero ``clip``; in that case the division by the
    step produces non-finite intermediate values.
    """

    @staticmethod
    def forward(ctx, input, c1, c2, Qp, Qn):
        # Work on a copy so the caller's tensor is left untouched.
        quantized = input.clone()
        with torch.no_grad():
            root_mean_sq = torch.sqrt(torch.mean(input ** 2))
            mean_abs = torch.mean(input.abs())
            clip = (c1 * root_mean_sq) - (c2 * mean_abs)
            step = 2 * clip / (Qp - Qn)

            quantized.div_(step)
            quantized.clamp_(Qn, Qp)
            quantized.round_()
            quantized.mul_(step)
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: only `input` receives a gradient, and
        # it is the incoming gradient itself.
        return grad_output, None, None, None, None